In [1]:
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import cv2
import os
import random
/homes/e34960/anaconda3/envs/patchdiff/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
In [2]:
# Folder holding the ground-truth records; every entry is later read with torch.load.
gt_images_folder = "/project/trinity/datasets/apricot/pub/apricot-mask/data_mask_v2"
# os.listdir returns entries in arbitrary order; sorting keeps the index -> file
# mapping stable between runs and machines.
gt_images_names = sorted(os.listdir(gt_images_folder))

print(len(gt_images_names))
873
In [3]:
def compute_vectorized_iou(given_mask, masks):
    """Compute the IoU between one reference mask and a collection of masks.

    Every array is binarised with ``> 0`` before the comparison, so boolean
    masks and 0/255 masks can be mixed.

    Args:
        given_mask: 2-D array, the reference mask.
        masks: Non-empty sequence of 2-D arrays with the same shape as
            ``given_mask``.

    Returns:
        1-D float array with one IoU value per entry of ``masks``. A pair whose
        union is empty (both masks contain no foreground) gets an IoU of 0.0
        instead of NaN, so ``np.argmax`` on the result never selects it over a
        real overlap.

    Raises:
        ValueError: From ``np.stack`` when ``masks`` is empty or the shapes differ.
    """
    reference = np.expand_dims(given_mask, axis=0) > 0
    candidates = np.stack(masks) > 0

    intersection_sum = np.logical_and(candidates, reference).sum(axis=(1, 2))
    union_sum = np.logical_or(candidates, reference).sum(axis=(1, 2))

    # Divide only where the union is non-empty; the remaining entries keep the 0.0 default.
    iou_values = np.zeros(union_sum.shape, dtype=np.float64)
    np.divide(intersection_sum, union_sum, out=iou_values, where=union_sum > 0)
    return iou_values

def getRandomImages(n=1, size=(512, 512)):
    """Load ``n`` distinct, randomly chosen image/mask pairs from ``gt_images_folder``.

    Each selected file from ``gt_images_names`` is read with ``torch.load`` and
    must provide the keys ``'Image'`` and ``'Mask'``. Both entries are squeezed,
    multiplied by 255, cast to ``uint8`` and resized with ``cv2.INTER_AREA``.

    Args:
        n: Number of pairs to draw (without replacement).
        size: Target ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        Tuple ``(imgs, img_masks)`` of two lists with ``n`` ``uint8`` arrays each;
        entry ``k`` of both lists comes from the same file.

    Raises:
        ValueError: From ``random.sample`` when ``n`` exceeds the number of files.
    """
    # range() already excludes its end point, so the whole listing is the
    # population and the last file can be drawn as well.
    chosen_indices = random.sample(range(len(gt_images_names)), n)

    imgs = []
    img_masks = []
    for file_idx in chosen_indices:
        # NOTE(review): torch.load unpickles the file - only use it on trusted data.
        img_info = torch.load(os.path.join(gt_images_folder, gt_images_names[file_idx]))

        # presumably the stored values are floats in [0, 1] - TODO confirm
        img = np.uint8(np.squeeze(img_info['Image']) * 255.0)
        imgs.append(cv2.resize(img, size, interpolation=cv2.INTER_AREA))

        img_mask = np.uint8(np.squeeze(img_info['Mask']) * 255.0)
        img_masks.append(cv2.resize(img_mask, size, interpolation=cv2.INTER_AREA))

    return imgs, img_masks
In [4]:
# Build the ViT-H Segment Anything model from the local checkpoint.
sam = sam_model_registry["vit_h"](checkpoint="/project/trinity/pretrained_models/sam_facebook/sam_vit_h_4b8939.pth")
# Prefer the GPU, but fall back to the (much slower) CPU instead of crashing
# when CUDA is not available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam.to(device=device)
# Automatic mask generator with the library's default settings.
mask_generator = SamAutomaticMaskGenerator(sam)
In [5]:
num_images = 20
num_top_masks = 4  # how many of the highest-ranked SAM masks are plotted next to the patch
img_list, mask_list = getRandomImages(num_images)

# Per image: rank (0 = largest value) of the adversarial-patch mask when all SAM
# masks are ordered by the magnitude of their highest-frequency FFT bin.
patch_indices = []

for i in range(num_images):
    masks = mask_generator.generate(img_list[i])
    if not masks:
        # SAM produced no segment: nothing to rank or to plot for this image.
        continue
    all_masks = [meta_info['segmentation'] for meta_info in masks]

    # The SAM mask with the best IoU against the ground-truth mask is treated as the patch.
    iou_vals = compute_vectorized_iou(mask_list[i], all_masks)
    patch_mask_idx = np.argmax(iou_vals)

    max_frequencies_list = []
    fft_list = []
    masked_img_list = []

    for current_mask in all_masks:
        # Keep only the pixels of this segment; the mask is broadcast over the channel axis.
        masked_img = img_list[i] * current_mask[:, :, np.newaxis]

        # NOTE(review): COLOR_BGR2GRAY assumes BGR channel order, while plt.imshow
        # below renders the same array as RGB - confirm the order of the stored images.
        gray_scale = cv2.cvtColor(masked_img, cv2.COLOR_BGR2GRAY)

        magnitude_spectrum = np.abs(np.fft.fft2(gray_scale))

        # Log-scaled, centred spectrum that is only used for display.
        fft_list.append(np.log(np.fft.fftshift(magnitude_spectrum) + 1))
        masked_img_list.append(masked_img)

        # In the unshifted FFT layout the highest (Nyquist) frequency of an axis of
        # length N is stored at index N // 2; index N - 1 holds frequency -1/N.
        nyquist_row, nyquist_col = (dim // 2 for dim in magnitude_spectrum.shape)
        max_frequencies_list.append(magnitude_spectrum[nyquist_row, nyquist_col])

    # Normalise all spectra of this image jointly to [0, 255]; the scale is
    # inverted, so strong components are rendered dark.
    fft_list = np.stack(fft_list, axis=-1)
    fft_max = np.max(fft_list)
    fft_min = np.min(fft_list)
    fft_list = 255.0 * (fft_max - fft_list) / (fft_max - fft_min)

    max_frequencies_list = np.array(max_frequencies_list)
    sorted_indices = np.argsort(max_frequencies_list)[::-1]

    # Rank of the patch = number of masks with a strictly larger value. This is the
    # first position of the patch value in the descending order (ties share the
    # best rank) and needs no float-equality search.
    patch_mask_freq = max_frequencies_list[patch_mask_idx]
    patch_indices.append(int(np.count_nonzero(max_frequencies_list > patch_mask_freq)))

    plt.imshow(img_list[i])
    plt.axis('off')
    plt.show()

    ###### Plot top frequencies and the adversarial patch frequency (magnitude spectrum)
    fig, axes = plt.subplots(nrows=2, ncols=num_top_masks + 1, figsize=(20, 9))

    # SAM may return fewer masks than columns; the surplus axes simply stay empty.
    for k in range(min(num_top_masks, len(masks))):
        axes[0][k].imshow(masked_img_list[sorted_indices[k]], cmap='gray')
        axes[0][k].set_title('Top {} frequency component'.format(k+1))

        axes[1][k].imshow(fft_list[:,:,sorted_indices[k]], cmap='gray')
        axes[1][k].set_title('Shifted FFT')

    # The last column always shows the mask matched to the adversarial patch.
    axes[0][-1].imshow(masked_img_list[patch_mask_idx], cmap='gray')
    axes[0][-1].set_title('Adversarial patch')

    axes[1][-1].imshow(fft_list[:,:,patch_mask_idx], cmap='gray')
    axes[1][-1].set_title('Shifted FFT')

    for ax in axes.ravel():
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    # Release the figure explicitly so a long run does not accumulate open figures.
    plt.close(fig)
In [7]:
# One unit-wide bin per integer rank: edges 0, 1, ..., max + 1. np.histogram
# treats the last bin as closed on both sides, so with edges ending at max the
# two largest ranks would be counted together in a single bar.
hist, bins = np.histogram(patch_indices, bins=np.arange(np.max(patch_indices) + 2))

plt.figure(figsize=(14,10))
plt.rc('font', size=8)

# Share of images (in percent) for which the patch mask landed on each rank.
percentage = hist / len(patch_indices) * 100

# Plot the percentage-wise histogram; bins[:-1] are the left edges, i.e. the ranks.
plt.bar(bins[:-1], percentage, align='center', width=0.8)
plt.xlabel('Value')
plt.ylabel('Percentage')
plt.title('Index of the adversarial patch among the highest frequencies for {} images'.format(num_images))

# Annotate every non-empty bar with its percentage.
for i in range(len(percentage)):
    if percentage[i] != 0.0:
        plt.text(bins[i], percentage[i], f'{percentage[i]:.1f}', ha='center')

plt.show()